
[WIP] feat: support multi-B weight tensors (DWDP) in CuTe DSL NVFP4 MoE#3041

Open
yhyang201 wants to merge 3 commits into flashinfer-ai:main from yhyang201:feat/cute-dsl-moe-multi-b-weights

Conversation


@yhyang201 yhyang201 commented Apr 13, 2026

Extend the Blackwell NVFP4 fused MoE (gather SwiGLU + finalize) kernels and their Python wrappers to accept w1/w2 weight, weight_sf and alpha as either a single tensor or a list of up to 4 tensors split along the expert dimension. The compiled kernel is specialized per multi-B config via b_tensor_l_sizes, with kernel-side branching selecting the right B tensor from the runtime expert index.

Also adds end-to-end tests verifying multi-B results match the single stacked-tensor baseline.

📌 Description

WIP — opening early for visibility / review feedback. Not ready to merge: perf parity vs. TRT-LLM still TBD, and I'm still sweeping the unit tests.

During FlashInfer's port of the TRT-LLM gather+SwiGLU kernel to CuTe DSL Python, the tile_size=256 path (use_2cta_instrs=True, where two CTAs cooperate on a larger MMA operation) produces numerically incorrect results — the kernel runs but gives wrong answers. An NVIDIA engineer discovered this in PR #2775 and disabled it as a workaround, leaving only tile_size=128. Since TRT-LLM's original kernel works correctly with tile_size=256, this is a bug introduced during the porting process. It doesn't affect DWDP functionality, but it halves the autotuner's tactic search space and may cost some performance on large-batch workloads.

Summary

Updates the CuTe DSL NVFP4 MoE kernels to accept weights as a list of tensors split along the expert dimension (up to 4), in addition to the existing single-tensor layout. This lands the DWDP (Distributed Weight Data Parallelism) support that the CUTLASS/TRT-LLM side already has.

  • w1_weight / w1_weight_sf / w1_alpha and the w2_* counterparts now accept Union[Tensor, List[Tensor]]
  • Kernel is specialized per multi-B config via b_tensor_l_sizes; the right B tensor is selected from the runtime expert index on the kernel side
  • Wrapper / functional / tuner paths all updated for the new layout
  • Adds end-to-end tests verifying multi-B results match the single stacked-tensor baseline

Ports the approach from NVIDIA/TensorRT-LLM#12136.
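As a rough sketch of the wrapper-side bookkeeping described above (the names `b_tensor_l_sizes` / offsets follow the PR, but the helper functions themselves are hypothetical, not flashinfer's actual API):

```python
from itertools import accumulate
from typing import Tuple

MAX_B_TENSORS = 4  # per-config limit described in the PR


def normalize_multi_b(w) -> list:
    # Accept a single tensor or a list of up to MAX_B_TENSORS expert splits.
    w_list = list(w) if isinstance(w, (list, tuple)) else [w]
    if not w_list:
        raise ValueError("expected a tensor or a non-empty list of tensors")
    if len(w_list) > MAX_B_TENSORS:
        raise ValueError(f"at most {MAX_B_TENSORS} B tensors are supported")
    return w_list


def expert_offsets(b_tensor_l_sizes: Tuple[int, ...]) -> Tuple[int, ...]:
    # Prefix sums over the per-tensor expert counts; offsets[i] is the first
    # global expert index held by B tensor i.
    return tuple(accumulate((0,) + tuple(b_tensor_l_sizes)))


def select_b(expert_idx: int, offsets: Tuple[int, ...]) -> Tuple[int, int]:
    # Map a global expert index to (B-tensor slot, local expert index),
    # mirroring the kernel-side branching on b_tensor_l_offsets.
    for i in range(len(offsets) - 1):
        if offsets[i] <= expert_idx < offsets[i + 1]:
            return i, expert_idx - offsets[i]
    raise IndexError("expert_idx out of range")
```

For example, with `b_tensor_l_sizes = (8, 8, 16)` the offsets are `(0, 8, 16, 32)`, and global expert 10 resolves to tensor 1, local index 2.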

🔍 Related Issues

#3036

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features
    • Support for splitting expert weights/scale/alpha across up to 4 tensors; APIs accept either a single tensor or a list with automatic expert-dimension handling and backward compatibility.
  • Validation
    • Added input validation for multi-tensor lists (non-empty, max 4, aligned lengths) and adjusted dimension derivation when lists are used.
  • Documentation
    • Public docstrings expanded to describe the single-tensor-or-list convention and expert-splitting behavior.
  • Tests
    • Added tests for multi-tensor partitions, backward compatibility, and runtime/autotune execution.

Contributor

coderabbitai bot commented Apr 13, 2026

📝 Walkthrough

Walkthrough

Added multi-B (multiple B-tensor / DWDP) support across CuteDSL MoE gather/finalize kernels, wrappers, and public NVFP4 APIs. Interfaces now accept single tensors or tuples/lists for B, SFB, and alpha; kernels/wrappers build per-B tuples, route TMA/loading per B via expert-index offsets, and update compilation/runtime argument plumbing.

Changes

Cohort / File(s) | Summary
Kernels — gather fusion
flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
Add MAX_B_TENSORS=4, support b_tensor_l_sizes, normalize b/sfb/alpha to tuples, build per‑B TMA atoms/tensors, replicate tiling/partitioning per possible B count, select loads by expert_idx + b_tensor_l_offsets, update wrapper to accept tuple ptrs.
Kernels — finalize fusion
flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
Same multi‑B changes as gather: tuple-based tma_atoms/tma_tensors, per‑B partitioning guarded by const_expr(num_b_tensors >= N), runtime selection of B/SFB/alpha by offsets, wrapper updated to tuple pointers and constructs b_tuple/b_sf_tuple/alpha_tuple.
High-level wrappers / API
flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py, .../blockscaled_contiguous_grouped_gemm_finalize_fusion.py
Public functions accept b, b_scale, alpha as single tensor or list; validate multi‑B lists (<=4), build pointer tuples, compute b_tensor_l_sizes, include sizes in compile cache key, remove num_experts from wrapper args and recompute at runtime.
Fused MoE entrypoints
flashinfer/fused_moe/cute_dsl/fused_moe.py
Public NVFP4 APIs and wrapper CuteDslMoEWrapper.run accept Union[Tensor, List[Tensor]] for weight/scale/alpha inputs; derive hidden/output sizes from first list element when lists provided; docstrings updated.
Tuner
flashinfer/fused_moe/cute_dsl/tuner.py
Tactic validation updated to accept w1_weight as tensor or list: computes num_local_experts as sum over list or previous single-tensor logic; intermediate size from first tensor for lists.
Tests
tests/moe/test_cute_dsl_fused_moe.py
New TestMultiBTensor suite validating multi‑B behavior (2/3/4 partitions, uneven splits), backward compatibility with 1‑element lists, wrapper acceptance, autotune run, and accuracy checks.
Cross-cutting small edits
...
Typing imports extended to Union; compile argument refactor to build/splat compile_args; cache keys updated to include b_tensor_l_sizes.
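The tuner-side change summarized above (accepting w1_weight as a tensor or a list) can be sketched as follows; `FakeTensor` and the gate+up halving are illustrative assumptions so the sketch runs without torch, not flashinfer's actual code:

```python
from dataclasses import dataclass
from typing import List, Tuple, Union


@dataclass
class FakeTensor:
    # Minimal stand-in for torch.Tensor so the sketch runs without torch.
    shape: Tuple[int, ...]


def num_local_experts(w1_weight: Union[FakeTensor, List[FakeTensor]]) -> int:
    # Lists are split along the expert dimension, so the expert count is the
    # sum of each split's leading dim; a single tensor keeps the old logic.
    if isinstance(w1_weight, (list, tuple)):
        return sum(w.shape[0] for w in w1_weight)
    return w1_weight.shape[0]


def intermediate_size(w1_weight: Union[FakeTensor, List[FakeTensor]]) -> int:
    # Derived from the first split when a list is given; w1 stacks the gate
    # and up projections, hence the halving (mirroring interm_size = n // 2
    # in the kernel excerpts below; an assumption, not verified here).
    first = w1_weight[0] if isinstance(w1_weight, (list, tuple)) else w1_weight
    return first.shape[1] // 2
```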

Sequence Diagram(s)

sequenceDiagram
    participant PyAPI as Python API / Caller
    participant Wrapper as NVFP4 Wrapper
    participant Kernel as Compiled CuTe Kernel
    participant DeviceMem as Device Memory / TMA / Shared Buffers

    PyAPI->>Wrapper: normalize inputs (tensor or list) -> build b_tuple, sfb_tuple, alpha_tuple, b_tensor_l_sizes
    Wrapper->>Wrapper: compute b_tensor_l_offsets, total_l, compile/lookup kernel (with b_tensor_l_sizes)
    Wrapper->>Kernel: invoke compiled kernel with tuple ptrs and metadata
    Kernel->>DeviceMem: partition TMA per B (const_expr branches) -> load appropriate B/SFB per expert_idx using offsets
    DeviceMem-->>Kernel: return loaded tiles into shared buffers
    Kernel->>DeviceMem: run grouped GEMM + epilogue (select alpha via tuple+offset)
    Kernel-->>Wrapper: write output to destination
    Wrapper-->>PyAPI: return result

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

cute-dsl, run-ci

Suggested reviewers

  • aleozlx
  • yzh119
  • samuellees
  • IwakuraRein
  • jiahanc
  • nv-yunzheq
  • bkryu
  • jimmyzho

Poem

"A rabbit hopped in code so spry,
Tuples of B like clouds in sky,
Offsets mapped with nimble feet,
Kernels pick the slice to meet,
Hops and flops — a multisliced pie 🐰✨"

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 inconclusive)

  • Description check (❓ Inconclusive): The description provides detailed context about the DWDP support and related work and notes the WIP status, but the required template checklist items remain unchecked and incomplete. Resolution: complete the pre-commit and tests checklist items, either checking them off or explaining why they are not applicable.
✅ Passed checks (2 passed)
  • Title check (✅ Passed): The title accurately describes the main feature, adding multi-B weight tensor (DWDP) support to CuTe DSL NVFP4 MoE, which aligns with the changeset.
  • Docstring Coverage (✅ Passed): Docstring coverage is 90.00%, which meets the required threshold of 80.00%.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for multiple B weight tensors (Distributed Weight Data Parallelism) in the Blackwell blockscaled MoE kernels. The changes enable the selection of B tensors and alpha values at runtime based on expert indices. Review feedback identified several critical issues: potential out-of-bounds accesses when retrieving alpha values in both gather and finalize kernels, and logic errors in the wrapper functions where the default single-B case (when b_tensor_l_sizes is None) leads to TypeError or incorrect layout dimensions due to offset padding.

Comment on lines +2977 to +3016
                alpha_val = alpha_tuple[0][expert_idx - self.b_tensor_l_offsets[0]]
                if cutlass.const_expr(self.num_b_tensors == 1):
                    pass  # Already initialized above
                elif cutlass.const_expr(self.num_b_tensors == 2):
                    if expert_idx >= self.b_tensor_l_offsets[1]:
                        alpha_val = alpha_tuple[1][
                            expert_idx - self.b_tensor_l_offsets[1]
                        ]
                elif cutlass.const_expr(self.num_b_tensors == 3):
                    if (
                        expert_idx >= self.b_tensor_l_offsets[1]
                        and expert_idx < self.b_tensor_l_offsets[2]
                    ):
                        alpha_val = alpha_tuple[1][
                            expert_idx - self.b_tensor_l_offsets[1]
                        ]
                    elif expert_idx >= self.b_tensor_l_offsets[2]:
                        alpha_val = alpha_tuple[2][
                            expert_idx - self.b_tensor_l_offsets[2]
                        ]
                else:
                    # 4 B tensors
                    if (
                        expert_idx >= self.b_tensor_l_offsets[1]
                        and expert_idx < self.b_tensor_l_offsets[2]
                    ):
                        alpha_val = alpha_tuple[1][
                            expert_idx - self.b_tensor_l_offsets[1]
                        ]
                    elif (
                        expert_idx >= self.b_tensor_l_offsets[2]
                        and expert_idx < self.b_tensor_l_offsets[3]
                    ):
                        alpha_val = alpha_tuple[2][
                            expert_idx - self.b_tensor_l_offsets[2]
                        ]
                    elif expert_idx >= self.b_tensor_l_offsets[3]:
                        alpha_val = alpha_tuple[3][
                            expert_idx - self.b_tensor_l_offsets[3]
                        ]
Contributor


high

The initial assignment to alpha_val at line 2977 uses alpha_tuple[0] with an index that could be out of bounds if expert_idx belongs to a subsequent tensor (i.e., expert_idx >= self.b_tensor_l_offsets[1]). While alpha_tuple[0] is a cute.Tensor and indexing might just perform pointer arithmetic, it is safer and more correct to guard the access within the num_b_tensors and expert_idx branches to ensure only the valid tensor for the current expert is accessed.

                # Select alpha from correct tensor based on expert_idx
                if cutlass.const_expr(self.num_b_tensors == 1):
                    alpha_val = alpha_tuple[0][expert_idx]
                elif cutlass.const_expr(self.num_b_tensors == 2):
                    if expert_idx < self.b_tensor_l_offsets[1]:
                        alpha_val = alpha_tuple[0][expert_idx]
                    else:
                        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
                elif cutlass.const_expr(self.num_b_tensors == 3):
                    if expert_idx < self.b_tensor_l_offsets[1]:
                        alpha_val = alpha_tuple[0][expert_idx]
                    elif expert_idx < self.b_tensor_l_offsets[2]:
                        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
                    else:
                        alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]]
                else:
                    # 4 B tensors
                    if expert_idx < self.b_tensor_l_offsets[1]:
                        alpha_val = alpha_tuple[0][expert_idx]
                    elif expert_idx < self.b_tensor_l_offsets[2]:
                        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
                    elif expert_idx < self.b_tensor_l_offsets[3]:
                        alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]]
                    else:
                        alpha_val = alpha_tuple[3][expert_idx - self.b_tensor_l_offsets[3]]

scale_k = k // scaling_vector_size
interm_size = n // 2
num_tiles = m // tile_size
total_l = self.b_tensor_l_offsets[self.num_b_tensors]
Contributor


high

The calculation of total_l using self.b_tensor_l_offsets[self.num_b_tensors] is problematic when b_tensor_l_sizes is None (the generic single-B case). In that case, self.num_b_tensors is 1 and self.b_tensor_l_offsets[1] is padded with 2**30 (line 540), leading to an incorrect and massive dimension for the c_sf layout. Since l was removed from the wrapper signature, there is no runtime fallback for the expert count. Consider restoring l to the signature or ensuring b_tensor_l_sizes is always a valid tuple in __init__.
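To make the reviewer's point concrete, here is a small sketch of the offset padding (the 2**30 sentinel and offset layout are taken from the comment above; the helper itself is illustrative, not the actual wrapper code):

```python
MAX_B_TENSORS = 4
PAD = 2**30  # sentinel used for unused offset slots (per the review comment)


def make_offsets(b_tensor_l_sizes):
    # Real prefix sums first, then sentinel padding for the unused slots.
    offsets = [0]
    for size in (b_tensor_l_sizes or ()):
        offsets.append(offsets[-1] + size)
    while len(offsets) < MAX_B_TENSORS + 1:
        offsets.append(PAD)
    return offsets


# Multi-B case: offsets[num_b_tensors] is the true total expert count.
assert make_offsets((8, 8))[2] == 16

# Single-B case with b_tensor_l_sizes=None: num_b_tensors defaults to 1,
# so total_l = offsets[1] picks up the sentinel, not the expert count.
assert make_offsets(None)[1] == 2**30
```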

Collaborator


+1

Comment on lines +2392 to +2431
                alpha_val = alpha_tuple[0][expert_idx - self.b_tensor_l_offsets[0]]
                if cutlass.const_expr(self.num_b_tensors == 1):
                    pass  # Already initialized above
                elif cutlass.const_expr(self.num_b_tensors == 2):
                    if expert_idx >= self.b_tensor_l_offsets[1]:
                        alpha_val = alpha_tuple[1][
                            expert_idx - self.b_tensor_l_offsets[1]
                        ]
                elif cutlass.const_expr(self.num_b_tensors == 3):
                    if (
                        expert_idx >= self.b_tensor_l_offsets[1]
                        and expert_idx < self.b_tensor_l_offsets[2]
                    ):
                        alpha_val = alpha_tuple[1][
                            expert_idx - self.b_tensor_l_offsets[1]
                        ]
                    elif expert_idx >= self.b_tensor_l_offsets[2]:
                        alpha_val = alpha_tuple[2][
                            expert_idx - self.b_tensor_l_offsets[2]
                        ]
                else:
                    # 4 B tensors
                    if (
                        expert_idx >= self.b_tensor_l_offsets[1]
                        and expert_idx < self.b_tensor_l_offsets[2]
                    ):
                        alpha_val = alpha_tuple[1][
                            expert_idx - self.b_tensor_l_offsets[1]
                        ]
                    elif (
                        expert_idx >= self.b_tensor_l_offsets[2]
                        and expert_idx < self.b_tensor_l_offsets[3]
                    ):
                        alpha_val = alpha_tuple[2][
                            expert_idx - self.b_tensor_l_offsets[2]
                        ]
                    elif expert_idx >= self.b_tensor_l_offsets[3]:
                        alpha_val = alpha_tuple[3][
                            expert_idx - self.b_tensor_l_offsets[3]
                        ]
Contributor


high

Similar to the gather kernel, the initial assignment to alpha_val at line 2392 performs a potentially out-of-bounds access on alpha_tuple[0] when expert_idx belongs to a later tensor. The access should be moved inside the conditional branches.

                # Select alpha from correct tensor based on expert_idx
                if cutlass.const_expr(self.num_b_tensors == 1):
                    alpha_val = alpha_tuple[0][expert_idx]
                elif cutlass.const_expr(self.num_b_tensors == 2):
                    if expert_idx < self.b_tensor_l_offsets[1]:
                        alpha_val = alpha_tuple[0][expert_idx]
                    else:
                        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
                elif cutlass.const_expr(self.num_b_tensors == 3):
                    if expert_idx < self.b_tensor_l_offsets[1]:
                        alpha_val = alpha_tuple[0][expert_idx]
                    elif expert_idx < self.b_tensor_l_offsets[2]:
                        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
                    else:
                        alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]]
                else:
                    # 4 B tensors
                    if expert_idx < self.b_tensor_l_offsets[1]:
                        alpha_val = alpha_tuple[0][expert_idx]
                    elif expert_idx < self.b_tensor_l_offsets[2]:
                        alpha_val = alpha_tuple[1][expert_idx - self.b_tensor_l_offsets[1]]
                    elif expert_idx < self.b_tensor_l_offsets[3]:
                        alpha_val = alpha_tuple[2][expert_idx - self.b_tensor_l_offsets[2]]
                    else:
                        alpha_val = alpha_tuple[3][expert_idx - self.b_tensor_l_offsets[3]]

alpha = cute.make_tensor(alpha_ptr, layout=cute.make_layout((l,)))

# Create B and alpha tensors using const_expr conditions
l_0 = self.b_tensor_l_sizes[0]
Contributor


high

Accessing self.b_tensor_l_sizes[0] will raise a TypeError if b_tensor_l_sizes is None (the generic single-B case). The wrapper should handle the case where expert sizes are not provided at initialization, likely by restoring the l parameter to the signature.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In
`@flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py`:
- Around line 2975-3016: The initial unconditional assignment to alpha_val from
alpha_tuple[0] can index out-of-bounds for expert_idx >=
self.b_tensor_l_offsets[1]; update the selection to mirror the B-tensor
selection pattern used elsewhere: remove the unconditional alpha_tuple[0] read
and implement explicit range checks against self.b_tensor_l_offsets for each
branch of cutlass.const_expr(self.num_b_tensors) so alpha_val is only read from
alpha_tuple[i] when expert_idx is within that tensor's [start, end) range; use
the same ordering and guards involving expert_idx, self.b_tensor_l_offsets,
num_b_tensors, alpha_tuple and alpha_val as in the B-tensor selection logic to
ensure safe indexing.

In
`@flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`:
- Around line 3192-3194: wrapper() currently indexes self.b_tensor_l_sizes[0]
(l_0) without ensuring b_tensor_l_sizes was provided in __init__, causing a
NoneType subscript; add an explicit guard at the start of wrapper() that checks
if self.b_tensor_l_sizes is None and raises a clear ValueError explaining that
b_tensor_l_sizes must be set for multi-B fused kernels (or alternatively make
b_tensor_l_sizes a required constructor parameter in __init__); update
references around l_0 and alpha_0 to rely on this validation so the subsequent
cute.make_tensor(...) call is safe.

In
`@flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`:
- Around line 378-381: The code now accepts list inputs for b, b_scale, and
alpha but lacks validation: ensure b_list, b_scale_list, and alpha_list are
non-empty, have identical lengths, and that every corresponding tensor split
agrees on the non-expert dimensions (shapes except the expert-split dimension)
before any indexing or kernel compilation (e.g., before using b_list[0] or
passing these lists into the compiled kernel path). Add explicit checks that
raise a clear error if any list is empty, if len(b_list) != len(b_scale_list) !=
len(alpha_list), or if any pairwise tensor shape mismatch exists on non-expert
dims; apply the same validation logic around the other normalization sites noted
(the blocks around the other occurrences you flagged: the b/b_scale/alpha
normalization at the later sections).

In `@tests/moe/test_cute_dsl_fused_moe.py`:
- Around line 1061-1062: Replace the file-local test gate decorator
`sm100_required` with the repo-standard check
`flashinfer.utils.is_sm100a_supported()` on the new test class so it uses the
canonical GPU support helper; locate the class decorated with
`@cute_dsl_available` and `@sm100_required`, remove `@sm100_required` and apply
the skip/require helper that calls `flashinfer.utils.is_sm100a_supported()` (or
the equivalent test-skip decorator that invokes it) so the test skips correctly
on unsupported SM100a devices instead of relying on the local `sm100_required`
function.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: f08ff5cf-440b-421f-b80f-4ac6097cb798

📥 Commits

Reviewing files that changed from the base of the PR and between b75740d and 038bf93.

📒 Files selected for processing (7)
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
  • flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
  • flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
  • flashinfer/fused_moe/cute_dsl/fused_moe.py
  • flashinfer/fused_moe/cute_dsl/tuner.py
  • tests/moe/test_cute_dsl_fused_moe.py

Comment on lines +378 to +381
# Normalize to lists for multi-B support
b_list = [b] if isinstance(b, torch.Tensor) else b
b_scale_list = [b_scale] if isinstance(b_scale, torch.Tensor) else b_scale
alpha_list = [alpha] if isinstance(alpha, torch.Tensor) else alpha
Contributor


⚠️ Potential issue | 🟠 Major

Validate the new multi-B inputs before indexing and compiling.

This path accepts lists now, but it never checks that they are non-empty, that b/b_scale/alpha have the same number of splits, or that every split agrees on the non-expert dimensions. Right now [] fails at b_list[0], and mismatched split counts/shapes fall through to obscure tuple-index/layout errors in the compiled kernel path.

🧩 Suggested validation
     b_list = [b] if isinstance(b, torch.Tensor) else b
     b_scale_list = [b_scale] if isinstance(b_scale, torch.Tensor) else b_scale
     alpha_list = [alpha] if isinstance(alpha, torch.Tensor) else alpha
+
+    if not b_list:
+        raise ValueError("b must be a tensor or a non-empty list of tensors")
+    if not (len(b_list) == len(b_scale_list) == len(alpha_list)):
+        raise ValueError("b, b_scale, and alpha must use the same number of splits")
+    if len(b_list) > 4:
+        raise ValueError("at most 4 B tensors are supported")
+
+    ref_n = b_list[0].shape[1]
+    ref_packed_k = b_list[0].shape[2]
+    for i, (bi, bsi, ai) in enumerate(zip(b_list, b_scale_list, alpha_list)):
+        if bi.shape[1:] != (ref_n, ref_packed_k):
+            raise ValueError(f"split {i} has inconsistent B shape: {tuple(bi.shape)}")
+        if bsi.shape[-1] != bi.shape[0]:
+            raise ValueError(
+                f"split {i} has inconsistent B-scale expert dim: {bsi.shape[-1]} != {bi.shape[0]}"
+            )
+        if ai.numel() != bi.shape[0]:
+            raise ValueError(
+                f"split {i} has inconsistent alpha length: {ai.numel()} != {bi.shape[0]}"
+            )

Also applies to: 385-390, 456-486

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`
around lines 378 - 381, The code now accepts list inputs for b, b_scale, and
alpha but lacks validation: ensure b_list, b_scale_list, and alpha_list are
non-empty, have identical lengths, and that every corresponding tensor split
agrees on the non-expert dimensions (shapes except the expert-split dimension)
before any indexing or kernel compilation (e.g., before using b_list[0] or
passing these lists into the compiled kernel path). Add explicit checks that
raise a clear error if any list is empty, if len(b_list) != len(b_scale_list) !=
len(alpha_list), or if any pairwise tensor shape mismatch exists on non-expert
dims; apply the same validation logic around the other normalization sites noted
(the blocks around the other occurrences you flagged: the b/b_scale/alpha
normalization at the later sections).

Comment on lines +1061 to +1062
@cute_dsl_available
@sm100_required
Contributor


⚠️ Potential issue | 🟡 Minor

Use the repo-standard SM100 skip helper for this new coverage.

These tests still rely on the file-local sm100_required gate, which only checks props.major == 10. That can enable the new multi-B cases on unsupported 10.x parts and create spurious failures. Please switch the new class to flashinfer.utils.is_sm100a_supported() instead.

As per coding guidelines: "Use flashinfer.utils functions (get_compute_capability(), is_sm90a_supported(), is_sm100a_supported()) to skip tests on unsupported GPU architectures"

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_cute_dsl_fused_moe.py` around lines 1061 - 1062, Replace the
file-local test gate decorator `sm100_required` with the repo-standard check
`flashinfer.utils.is_sm100a_supported()` on the new test class so it uses the
canonical GPU support helper; locate the class decorated with
`@cute_dsl_available` and `@sm100_required`, remove `@sm100_required` and apply
the skip/require helper that calls `flashinfer.utils.is_sm100a_supported()` (or
the equivalent test-skip decorator that invokes it) so the test skips correctly
on unsupported SM100a devices instead of relying on the local `sm100_required`
function.

- Safe alpha indexing with pre-initialization before const_expr branches
- NoneType guard: raise ValueError when b_tensor_l_sizes=None
- Input validation for multi-B weight lists (empty, max 4, length match)
- Fix test imports to use top-level flashinfer module

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (2)
tests/moe/test_cute_dsl_fused_moe.py (1)

1420-1421: Inconsistent import pattern for autotune.

Line 1420 uses from flashinfer.autotuner import autotune, but the rest of the file (e.g., line 408) uses from flashinfer import autotune. For consistency within this file, prefer the top-level import.

♻️ Proposed fix for import consistency
-        from flashinfer.autotuner import autotune
+        from flashinfer import autotune
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/moe/test_cute_dsl_fused_moe.py` around lines 1420 - 1421, The file
mixes import styles for autotune; change the line that reads "from
flashinfer.autotuner import autotune" to the top-level form "from flashinfer
import autotune" so imports are consistent with other uses in this test (e.g.,
the earlier import at line ~408); leave the import of cute_dsl_fused_moe_nvfp4
unchanged.
flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py (1)

463-465: Use ValueError instead of assert for input validation consistency.

Lines 458-462 correctly use ValueError for the None check, but lines 463-465 use assert for the length check. Assertions can be disabled with Python's -O flag, making this validation bypassable in optimized builds.

♻️ Proposed fix for consistent error handling
-        assert len(b_tensor_l_sizes) <= self.MAX_B_TENSORS, (
-            f"Max {self.MAX_B_TENSORS} B tensors, got {len(b_tensor_l_sizes)}"
-        )
+        if len(b_tensor_l_sizes) > self.MAX_B_TENSORS:
+            raise ValueError(
+                f"Max {self.MAX_B_TENSORS} B tensors, got {len(b_tensor_l_sizes)}"
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`
around lines 463 - 465, Replace the runtime assertion with a ValueError to
ensure input validation cannot be disabled: where the code currently does
"assert len(b_tensor_l_sizes) <= self.MAX_B_TENSORS, ..." raise a ValueError
with a descriptive message (referencing b_tensor_l_sizes and self.MAX_B_TENSORS)
so the length check always executes (e.g., in the same scope as the existing
None check that uses ValueError).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In
`@flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py`:
- Around line 4063-4067: The SFC tensor c_sf is being created with its L
dimension set to total_l which mismatches the backing allocation; change the
layout passed to cute.make_ordered_layout when creating c_sf (in the c_sf =
cute.make_tensor(...) call using c_sf_ptr) so the last dimension is 1 instead of
total_l — i.e. use (32, 4, m // 128, 4, interm_size // (scaling_vector_size *
4), 1) with the same order=(2, 1, 4, 0, 3, 5) so the tensor’s L dimension
remains 1 while keeping the same ordering.

In
`@flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py`:
- Around line 418-439: Loop over each split in b_list (using enumerate) and
validate per-split properties instead of only checking the first element: assert
each bi (from b_list) is non-empty, on CUDA, has the same second-dimension as
b_list[0] (bi.shape[1] == n), and that bi.size(0) matches the length of the
corresponding alpha_list[i]; also assert corresponding b_scale_list[i] is
present and on CUDA. Update validations around b_list, b_scale_list, alpha_list,
a, n and k (references: b_list, b_scale_list, alpha_list, a, n, k, num_experts)
so the kernel never receives an empty split, a device-mismatched tensor, or
mismatched per-split dimensions.

---

Nitpick comments:
In
`@flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py`:
- Around line 463-465: Replace the runtime assertion with a ValueError to ensure
input validation cannot be disabled: where the code currently does "assert
len(b_tensor_l_sizes) <= self.MAX_B_TENSORS, ..." raise a ValueError with a
descriptive message (referencing b_tensor_l_sizes and self.MAX_B_TENSORS) so the
length check always executes (e.g., in the same scope as the existing None check
that uses ValueError).

In `@tests/moe/test_cute_dsl_fused_moe.py`:
- Around line 1420-1421: The file mixes import styles for autotune; change the
line that reads "from flashinfer.autotuner import autotune" to the top-level
form "from flashinfer import autotune" so imports are consistent with other uses
in this test (e.g., the earlier import at line ~408); leave the import of
cute_dsl_fused_moe_nvfp4 unchanged.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 99f12819-35d9-4237-9b2e-4093fe716fc2

📥 Commits

Reviewing files that changed from the base of the PR and between 038bf93 and 9edbe9b.

📒 Files selected for processing (5)
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
  • flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
  • flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py
  • tests/moe/test_cute_dsl_fused_moe.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_grouped_gemm_finalize_fusion.py

Comment on lines 4063 to 4067
     c_sf = cute.make_tensor(
         c_sf_ptr,
         layout=cute.make_ordered_layout(
-            (32, 4, m // 128, 4, interm_size // (scaling_vector_size * 4), l),
+            (32, 4, m // 128, 4, interm_size // (scaling_vector_size * 4), total_l),
             order=(2, 1, 4, 0, 3, 5),

⚠️ Potential issue | 🟠 Major

Keep the SFC tensor's L dimension at 1.

out_scale is still allocated by the public wrapper as (..., 1), but this wrapper now reinterprets the same buffer as (..., total_l). On multi-B FP4 paths that gives CuTe a different layout than the backing allocation, and the output tensor still has flattened M x N x 1 semantics here.

Proposed fix
         c_sf = cute.make_tensor(
             c_sf_ptr,
             layout=cute.make_ordered_layout(
-                (32, 4, m // 128, 4, interm_size // (scaling_vector_size * 4), total_l),
+                (32, 4, m // 128, 4, interm_size // (scaling_vector_size * 4), 1),
                 order=(2, 1, 4, 0, 3, 5),
             ),
         )

Comment on lines +418 to 439
# Normalize to lists for multi-B support
b_list = [b] if isinstance(b, torch.Tensor) else list(b)
b_scale_list = [b_scale] if isinstance(b_scale, torch.Tensor) else list(b_scale)
alpha_list = [alpha] if isinstance(alpha, torch.Tensor) else list(alpha)

# Validate multi-B inputs
assert len(b_list) > 0, "Weight tensor list must not be empty"
assert len(b_list) <= 4, f"Maximum 4 weight tensors supported, got {len(b_list)}"
assert len(b_list) == len(b_scale_list) == len(alpha_list), (
f"b, b_scale, alpha lists must have same length: "
f"{len(b_list)}, {len(b_scale_list)}, {len(alpha_list)}"
)

     # Validate inputs
     assert a.device.type == "cuda", "Input tensors must be on CUDA device"
-    assert b.device.type == "cuda", "Input tensors must be on CUDA device"
+    assert b_list[0].device.type == "cuda", "Input tensors must be on CUDA device"

     # Get dimensions
     seq_len = a.shape[0]
-    num_experts = b.shape[0]
-    n = b.shape[1]  # This is 2*intermediate_size
+    num_experts = sum(bi.size(0) for bi in b_list)
+    n = b_list[0].shape[1]  # This is 2*intermediate_size
     k = a.shape[1]

⚠️ Potential issue | 🟠 Major

Validate each split tensor, not just the list length.

The new checks still allow invalid multi-B configurations: an empty first split, a later split on a different device, a later split with different (N, K), or an alpha[i] whose length does not match b[i].size(0). Those all flow into the kernel, which assumes split 0 exists and reuses the first split's shape for every other split.

Proposed fix
-    assert len(b_list) > 0, "Weight tensor list must not be empty"
-    assert len(b_list) <= 4, f"Maximum 4 weight tensors supported, got {len(b_list)}"
-    assert len(b_list) == len(b_scale_list) == len(alpha_list), (
-        f"b, b_scale, alpha lists must have same length: "
-        f"{len(b_list)}, {len(b_scale_list)}, {len(alpha_list)}"
-    )
+    if len(b_list) == 0:
+        raise ValueError("Weight tensor list must not be empty")
+    if len(b_list) > 4:
+        raise ValueError(f"Maximum 4 weight tensors supported, got {len(b_list)}")
+    if len(b_list) != len(b_scale_list) or len(b_list) != len(alpha_list):
+        raise ValueError(
+            "b, b_scale, and alpha must contain the same number of splits"
+        )
+
+    ref_nk = b_list[0].shape[1:]
+    for i, (bi, bsi, ai) in enumerate(zip(b_list, b_scale_list, alpha_list)):
+        if bi.size(0) == 0:
+            raise ValueError(f"b[{i}] must contain at least one expert")
+        if bi.device != a.device or bsi.device != a.device or ai.device != a.device:
+            raise ValueError(
+                f"All split tensors must be on {a.device}; "
+                f"got b[{i}]={bi.device}, b_scale[{i}]={bsi.device}, alpha[{i}]={ai.device}"
+            )
+        if bi.shape[1:] != ref_nk:
+            raise ValueError(
+                f"All B splits must share the same (N, K); "
+                f"expected {ref_nk}, got {bi.shape[1:]} for b[{i}]"
+            )
+        if ai.numel() != bi.size(0):
+            raise ValueError(
+                f"alpha[{i}] must have {bi.size(0)} entries, got {ai.numel()}"
+            )
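As intuition for why the multi-B path must match the stacked baseline, here is a tiny NumPy sketch (hypothetical shapes, a plain float grouped GEMM rather than the fused NVFP4 kernel) showing that splitting weights along the expert dimension leaves per-row results unchanged:

```python
import numpy as np

rng = np.random.default_rng(1)
# Hypothetical sizes: 6 experts, weights of shape (N, K) = (8, 4), 12 token rows.
E, N, K, rows = 6, 8, 4, 12
w_stacked = rng.standard_normal((E, N, K))                 # single stacked tensor
w_splits = [w_stacked[:2], w_stacked[2:5], w_stacked[5:]]  # 3 splits along experts
x = rng.standard_normal((rows, K))
expert_of_row = rng.integers(0, E, size=rows)

# Exclusive prefix sum over split sizes: offsets[i] is the first global expert
# owned by split i; offsets[-1] is the total expert count.
offsets = np.cumsum([0] + [w.shape[0] for w in w_splits])

def grouped_gemm(weight_for):
    # Per-row grouped GEMM: each row uses its own expert's weight.
    return np.stack([weight_for(e) @ x[i] for i, e in enumerate(expert_of_row)])

def split_weight(e):
    # Map a global expert index to (split, local index) via the offsets.
    t = int(np.searchsorted(offsets, e, side="right")) - 1
    return w_splits[t][e - offsets[t]]

ref = grouped_gemm(lambda e: w_stacked[e])
out = grouped_gemm(split_weight)
assert np.allclose(ref, out)
print(offsets.tolist())  # [0, 2, 5, 6]
```

The end-to-end tests in this PR check the same equivalence with the real kernel: multi-B outputs must match the single stacked-tensor baseline.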

@nv-yunzheq

/bot run

@flashinfer-bot

GitLab MR !550 has been created, and the CI pipeline #48542959 is currently running. I'll report back once the pipeline job completes.

@nv-yunzheq left a comment

Thanks for contributing to the project!
We'd like to know whether there is a performance regression in the non-data-parallelism case.
Moreover, the autotuner design might need more thought. Right now it keys on the number of experts, as before, which may give sub-optimal performance in the data-parallelism case. Should we also keep the degree of parallelism (the number of splits) as a tuning parameter?
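One possible direction (purely illustrative; the names below are hypothetical, not FlashInfer's actual autotuner API) is to fold the split configuration into the tuning key so DWDP configs with the same total expert count are tuned separately:

```python
# Hypothetical tuning key: including b_tensor_l_sizes distinguishes DWDP split
# configurations that share the same total number of experts.
def tuning_key(num_experts, b_tensor_l_sizes, tile_size):
    return (num_experts, tuple(b_tensor_l_sizes), tile_size)

single = tuning_key(256, (256,), 128)          # one stacked weight tensor
dwdp = tuning_key(256, (64, 64, 64, 64), 128)  # four splits, same expert total
assert single != dwdp  # tuned as distinct configs despite equal expert counts
```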

# as a tuple (even for single-B, e.g. (256,)).
if b_tensor_l_sizes is None:
raise ValueError(
"b_tensor_l_sizes is required. Pass a tuple with the number of "

It seems the description of b_tensor_l_sizes is inconsistent: here it is required, but the function signature declares it Optional.

topk: cutlass.Int64,
raster_along_m: bool = False,
enable_pdl: bool = True,
b_tensor_l_sizes: Optional[Tuple[int, ...]] = None,

I'd prefer to put the enable_pdl parameter last, since TensorRT-LLM doesn't have such a parameter. That makes it easier for developers when porting optimizations.

scale_k = k // scaling_vector_size
interm_size = n // 2
num_tiles = m // tile_size
total_l = self.b_tensor_l_offsets[self.num_b_tensors]

+1
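For reference, a minimal Python sketch of the semantics this snippet implies (hypothetical helper names; b_tensor_l_offsets is presumably an exclusive prefix sum of b_tensor_l_sizes, so its last entry is total_l, and the kernel selects a B tensor from the runtime expert index):

```python
def make_offsets(b_tensor_l_sizes):
    # Exclusive prefix sum: offsets[i] is the first global expert owned by
    # tensor i; offsets[-1] is the total expert count (total_l in the kernel).
    offsets = [0]
    for size in b_tensor_l_sizes:
        offsets.append(offsets[-1] + size)
    return offsets

def select_b_tensor(expert_idx, offsets):
    # Linear scan over at most MAX_B_TENSORS (4) entries, mirroring the
    # kernel-side branching this PR describes.
    for i in range(len(offsets) - 1):
        if offsets[i] <= expert_idx < offsets[i + 1]:
            return i, expert_idx - offsets[i]
    raise IndexError(f"expert {expert_idx} out of range {offsets[-1]}")

offsets = make_offsets((64, 64, 128))
print(offsets[-1])                    # total_l == 256
print(select_b_tensor(130, offsets))  # (2, 2): third tensor, local expert 2
```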

…_l_sizes is None

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py (1)

420-440: ⚠️ Potential issue | 🟠 Major

Validate every split before deriving shapes and tuple pointers.

This still only checks list counts and b_list[0].device. An empty first split, a later split on a different device or with a different (N, K), or an alpha[i] whose length does not match b[i].size(0) will flow into the kernel, which uses split 0 for shape/layout setup and reads alpha_tuple[0][0] unconditionally in the epilogue.

Proposed fix
-    assert len(b_list) > 0, "Weight tensor list must not be empty"
-    assert len(b_list) <= 4, f"Maximum 4 weight tensors supported, got {len(b_list)}"
-    assert len(b_list) == len(b_scale_list) == len(alpha_list), (
-        f"b, b_scale, alpha lists must have same length: "
-        f"{len(b_list)}, {len(b_scale_list)}, {len(alpha_list)}"
-    )
+    if len(b_list) == 0:
+        raise ValueError("Weight tensor list must not be empty")
+    if len(b_list) > 4:
+        raise ValueError(f"Maximum 4 weight tensors supported, got {len(b_list)}")
+    if len(b_list) != len(b_scale_list) or len(b_list) != len(alpha_list):
+        raise ValueError(
+            "b, b_scale, and alpha must contain the same number of splits"
+        )
+
+    ref_shape = b_list[0].shape[1:]
+    for i, (bi, bsi, ai) in enumerate(zip(b_list, b_scale_list, alpha_list)):
+        if bi.size(0) == 0:
+            raise ValueError(f"b[{i}] must contain at least one expert")
+        if bi.device != a.device or bsi.device != a.device or ai.device != a.device:
+            raise ValueError(
+                f"Split {i} must be on {a.device}: "
+                f"b={bi.device}, b_scale={bsi.device}, alpha={ai.device}"
+            )
+        if bi.shape[1:] != ref_shape:
+            raise ValueError(
+                f"All B splits must share the same trailing shape; "
+                f"expected {ref_shape}, got {bi.shape[1:]} for b[{i}]"
+            )
+        if ai.numel() != bi.size(0):
+            raise ValueError(
+                f"alpha[{i}] must have {bi.size(0)} entries, got {ai.numel()}"
+            )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py`
around lines 420 - 440, The code currently only checks counts and b_list[0], but
you must validate every split before using split-0 for shapes/tuples: iterate i
over b_list and assert each b_list[i] is non-empty (b_list[i].size(0) > 0), on
the same device as a (b_list[i].device == a.device), has the same column
dimension as the first split (b_list[i].shape[1] == b_list[0].shape[1]), and
that the corresponding alpha_list[i] length matches the number of experts in
that split (if alpha_list[i] is tensor, alpha_list[i].shape[0] ==
b_list[i].size(0); if list, len(alpha_list[i]) == b_list[i].size(0)); also
validate b_scale_list[i] if it represents per-expert scales. Only after these
per-split validations compute num_experts and derive n, and ensure any later use
like alpha_tuple[0][0] is safe because every split was checked.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In
`@flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py`:
- Around line 833-840: Before any const_expr branches that index b_tuple[1..3],
sfb_tuple[1..3], or alpha_tuple[1..3], validate that the passed-in tuples have
at least self.num_b_tensors elements and raise a clear input error if not;
specifically, add an arity check just after constructing b_tuple, sfb_tuple (and
alpha_tuple where created) that compares len(b_tuple), len(sfb_tuple), and
len(alpha_tuple) against self.num_b_tensors and raises a ValueError (or
appropriate exception) with a descriptive message mentioning
b_tuple/sfb_tuple/alpha_tuple and expected vs actual counts so downstream
indexing in the const_expr branches (and related code paths around the methods
using self.num_b_tensors) cannot produce internal index failures.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 7c2869ff-4f42-4d61-99c3-366698f1dea9

📥 Commits

Reviewing files that changed from the base of the PR and between 9edbe9b and 5054cb1.

📒 Files selected for processing (2)
  • flashinfer/fused_moe/cute_dsl/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py
  • flashinfer/fused_moe/cute_dsl/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py

Comment on lines +833 to +840
     # Handle tuple of B tensors
     b_tuple = b if isinstance(b, tuple) else (b,)
     sfb_tuple = sfb if isinstance(sfb, tuple) else (sfb,)
     self.b_dtype: Type[cutlass.Numeric] = b_tuple[0].element_type
     self.c_dtype: Type[cutlass.Numeric] = c.element_type
     self.sf_dtype: Type[cutlass.Numeric] = sfa.element_type
     self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode()
-    self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode()
+    self.b_major_mode = utils.LayoutEnum.from_tensor(b_tuple[0]).mma_major_mode()

⚠️ Potential issue | 🟠 Major

Check tuple arity against self.num_b_tensors before the const_expr branches.

These paths index b_tuple[1..3], sfb_tuple[1..3], and alpha_tuple[1..3] based only on self.num_b_tensors. A direct caller that passes too few tensors gets an internal index failure instead of a clear input error.

Proposed fix
         b_tuple = b if isinstance(b, tuple) else (b,)
         sfb_tuple = sfb if isinstance(sfb, tuple) else (sfb,)
+        alpha_tuple = alpha if isinstance(alpha, tuple) else (alpha,)
+        if (
+            len(b_tuple) != self.num_b_tensors
+            or len(sfb_tuple) != self.num_b_tensors
+            or len(alpha_tuple) != self.num_b_tensors
+        ):
+            raise ValueError(
+                f"Expected {self.num_b_tensors} B/SFB/alpha tensors, got "
+                f"{len(b_tuple)}, {len(sfb_tuple)}, {len(alpha_tuple)}"
+            )
         self.b_dtype: Type[cutlass.Numeric] = b_tuple[0].element_type
         self.c_dtype: Type[cutlass.Numeric] = c.element_type
         self.sf_dtype: Type[cutlass.Numeric] = sfa.element_type
@@
-        alpha_tuple = alpha if isinstance(alpha, tuple) else (alpha,)

Also applies to: 998-999


3 participants